"""
@author: chenzhenhua
@project: jf_fashion
@file: fashion.py
@time: 2021/6/9 0009 14:53
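@desc: Build PSPNet segmentation models from locally stored pretrained weights and segment a sample image.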
"""

#from keras_segmentation.pretrained import pspnet_50_ADE_20K , pspnet_101_cityscapes, pspnet_101_voc12
from keras_segmentation.models.all_models import model_from_name

def model_from_checkpoint_path(model_config, latest_weights):
    """Instantiate a segmentation model from a config dict and load its weights.

    Args:
        model_config: dict with keys ``'model_class'`` (a key of
            ``model_from_name``), ``'n_classes'``, ``'input_height'`` and
            ``'input_width'``.
        latest_weights: path to the weight file passed to ``load_weights``.

    Returns:
        The constructed model with the weights loaded.
    """
    # Look up the model constructor registered under the configured name.
    builder = model_from_name[model_config['model_class']]
    net = builder(
        model_config['n_classes'],
        input_height=model_config['input_height'],
        input_width=model_config['input_width'],
    )
    net.load_weights(latest_weights)
    return net

def pspnet_50_ADE_20K():
    """Build a PSPNet-50 model for ADE20K and load its pretrained weights.

    Uses a 473x473 input and 150 output classes, with weights read from
    ``data/model/pspnet50_ade20k.h5`` (a relative path, resolved against the
    current working directory).

    Returns:
        The model produced by ``model_from_checkpoint_path``.
    """
    weights_path = 'data/model/pspnet50_ade20k.h5'
    config = dict(
        input_height=473,
        input_width=473,
        n_classes=150,
        model_class="pspnet_50",
    )
    return model_from_checkpoint_path(config, weights_path)

def pspnet_101_cityscapes(latest_weights='data/model/pspnet101_cityscapes.h5'):
    """Build a PSPNet-101 model for Cityscapes and load its pretrained weights.

    The architecture is configured with a 713x713 input and 19 output
    classes. The original code loaded ``pspnet50_ade20k.h5`` here. That is
    the 150-class PSPNet-50 ADE20K checkpoint, which does not match this
    19-class PSPNet-101 architecture, so the default now points at a
    Cityscapes checkpoint instead.

    Args:
        latest_weights: path to the weight file for this architecture.
            Defaults to ``'data/model/pspnet101_cityscapes.h5'``.

    Returns:
        The model produced by ``model_from_checkpoint_path``.
    """
    # NOTE(review): the default file name follows the name keras_segmentation's
    # pretrained downloader uses for these weights; confirm the file exists
    # under data/model/, or pass the correct path explicitly.
    model_config = {
        "input_height": 713,
        "input_width": 713,
        "n_classes": 19,
        "model_class": "pspnet_101",
    }
    return model_from_checkpoint_path(model_config, latest_weights)

# Load the PSPNet-50 model with the ADE20K weights stored under data/model/.
model = pspnet_50_ADE_20K()
# Alternative: PSPNet-101 with Cityscapes weights (713x713 input, 19 classes).
# model = pspnet_101_cityscapes()

# Keras' Model.summary() prints the table itself and returns None; wrapping it
# in print() would emit an extra "None" line after the summary.
model.summary()

# Segment one test image and write the rendered result to out_fname.
out = model.predict_segmentation(
    inp="data/example_dataset/images_prepped_test/1.png",
    out_fname="data/example_dataset/out2.png"
)